{
  "cells": [
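    {
      "cell_type": "markdown",
      "source": [
        "The Jupyter notebook for this article is available in the [Chapter 4 code repository](https://github.com/datawhalechina/learn-nlp-with-transformers/tree/main/docs/%E7%AF%87%E7%AB%A04-%E4%BD%BF%E7%94%A8Transformers%E8%A7%A3%E5%86%B3NLP%E4%BB%BB%E5%8A%A1).\n",
        "\n",
        "We recommend opening this tutorial directly in Google Colab, where the datasets and models download quickly.\n",
        "If you are running this notebook in Colab, you may need to install the 🤗 Transformers and 🤗 Datasets libraries. The following cell installs them, along with `rouge-score` and `nltk`."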
      ],
      "metadata": {
        "id": "X4cRE8IbIrIV"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "source": [
        "! pip install datasets transformers rouge-score nltk"
      ],
      "outputs": [],
      "metadata": {
        "id": "MOsHUjgdIrIW"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "For distributed training, see [here](https://github.com/huggingface/transformers/tree/master/examples/seq2seq)."
      ],
      "metadata": {
        "id": "HFASsisvIrIb"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "Fine-tuning a Transformer model for summarization"
      ],
      "metadata": {
        "id": "rEJBSTyZIrIb"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "In this notebook, we show how to fine-tune a pretrained model from [🤗 Transformers](https://github.com/huggingface/transformers) on a summarization task. We use the [XSum dataset](https://arxiv.org/pdf/1808.08745.pdf), which contains BBC articles, each paired with a one-sentence summary. Here is an example:\n",
        "\n",
        "![Widget inference on a summarization task](https://github.com/huggingface/notebooks/blob/master/examples/images/summarization.png?raw=1)\n",
        "\n",
        "We will show how to load the dataset for this task and how to fine-tune a model on it with the `Trainer` API from 🤗 Transformers."
      ],
      "metadata": {
        "id": "kTCFado4IrIc"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "source": [
        "model_checkpoint = \"t5-small\""
      ],
      "outputs": [],
      "metadata": {
        "id": "iO8WOo1hdJT0"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "In principle, this notebook works with any pretrained Transformer checkpoint that has a sequence-to-sequence head, on any summarization task. Here we use the [`t5-small`](https://huggingface.co/t5-small) checkpoint.\n"
      ],
      "metadata": {
        "id": "4RRkXuteIrIh"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Loading the dataset"
      ],
      "metadata": {
        "id": "whPRbBNbIrIl"
      }
    },
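    {
      "cell_type": "markdown",
      "source": [
        "We use the [🤗 Datasets](https://github.com/huggingface/datasets) library to load the data and the corresponding evaluation metric. Both can be loaded with a single call, to `load_dataset` and `load_metric` respectively.\n"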
      ],
      "metadata": {
        "id": "W7QYTpxXIrIl"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "source": [
        "from datasets import load_dataset, load_metric\r\n",
        "\r\n",
        "raw_datasets = load_dataset(\"xsum\")\r\n",
        "metric = load_metric(\"rouge\")"
      ],
      "outputs": [],
      "metadata": {
        "id": "IreSlFmlIrIm"
      }
    },
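    {
      "cell_type": "markdown",
      "source": [
        "Alternatively, you can download the data from the [link](https://gas.graviti.cn/dataset/datawhale/Xsum) we provide and unzip it, copy the extracted json files and the `bbc-summary-data` folder into the `docs/篇章4-使用Transformers解决NLP任务/datasets/xsum` directory, and then load them with the code below. In the same way, the `rouge` metric can be loaded from the script we provide."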
      ],
      "metadata": {}
    },
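    {
      "cell_type": "code",
      "execution_count": null,
      "source": [
        "import os\r\n",
        "from datasets import load_dataset, load_metric\r\n",
        "\r\n",
        "# Local copy of the XSum data and its loading script\r\n",
        "data_path = './dataset/xsum/'\r\n",
        "path = os.path.join(data_path, 'xsum.py')\r\n",
        "cache_dir = os.path.join(data_path, 'cache')\r\n",
        "data_files = {'data': data_path}\r\n",
        "# Assign to raw_datasets so the rest of the notebook works unchanged\r\n",
        "raw_datasets = load_dataset(path, data_files=data_files, cache_dir=cache_dir)\r\n",
        "\r\n",
        "# Local copy of the ROUGE metric script\r\n",
        "metric_path = './dataset/metrics/rouge.py'\r\n",
        "metric = load_metric(metric_path)"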
      ],
      "outputs": [],
      "metadata": {}
    },
    {
      "cell_type": "markdown",
      "source": [
        "The `raw_datasets` object is a [`DatasetDict`](https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasetdict). The training, validation, and test sets are available under the keys `train`, `validation`, and `test`."
      ],
      "metadata": {
        "id": "RzfPtOMoIrIu"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "source": [
        "raw_datasets"
      ],
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "DatasetDict({\n",
              "    train: Dataset({\n",
              "        features: ['document', 'summary', 'id'],\n",
              "        num_rows: 204045\n",
              "    })\n",
              "    validation: Dataset({\n",
              "        features: ['document', 'summary', 'id'],\n",
              "        num_rows: 11332\n",
              "    })\n",
              "    test: Dataset({\n",
              "        features: ['document', 'summary', 'id'],\n",
              "        num_rows: 11334\n",
              "    })\n",
              "})"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 4
        }
      ],
      "metadata": {
        "id": "GWiVUF0jIrIv",
        "outputId": "35e3ea43-f397-4a54-c90c-f2cf8d36873e"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "To look at an example, select a split (`train`, `validation`, or `test`) and then an index:"
      ],
      "metadata": {
        "id": "u3EtYfeHIrIz"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "source": [
        "raw_datasets[\"train\"][0]"
      ],
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "{'document': 'Recent reports have linked some France-based players with returns to Wales.\\n\"I\\'ve always felt - and this is with my rugby hat on now; this is not region or WRU - I\\'d rather spend that money on keeping players in Wales,\" said Davies.\\nThe WRU provides £2m to the fund and £1.3m comes from the regions.\\nFormer Wales and British and Irish Lions fly-half Davies became WRU chairman on Tuesday 21 October, succeeding deposed David Pickering following governing body elections.\\nHe is now serving a notice period to leave his role as Newport Gwent Dragons chief executive after being voted on to the WRU board in September.\\nDavies was among the leading figures among Dragons, Ospreys, Scarlets and Cardiff Blues officials who were embroiled in a protracted dispute with the WRU that ended in a £60m deal in August this year.\\nIn the wake of that deal being done, Davies said the £3.3m should be spent on ensuring current Wales-based stars remain there.\\nIn recent weeks, Racing Metro flanker Dan Lydiate was linked with returning to Wales.\\nLikewise the Paris club\\'s scrum-half Mike Phillips and centre Jamie Roberts were also touted for possible returns.\\nWales coach Warren Gatland has said: \"We haven\\'t instigated contact with the players.\\n\"But we are aware that one or two of them are keen to return to Wales sooner rather than later.\"\\nSpeaking to Scrum V on BBC Radio Wales, Davies re-iterated his stance, saying keeping players such as Scarlets full-back Liam Williams and Ospreys flanker Justin Tipuric in Wales should take precedence.\\n\"It\\'s obviously a limited amount of money [available]. The union are contributing 60% of that contract and the regions are putting £1.3m in.\\n\"So it\\'s a total pot of just over £3m and if you look at the sorts of salaries that the... guys... 
have been tempted to go overseas for [are] significant amounts of money.\\n\"So if we were to bring the players back, we\\'d probably get five or six players.\\n\"And I\\'ve always felt - and this is with my rugby hat on now; this is not region or WRU - I\\'d rather spend that money on keeping players in Wales.\\n\"There are players coming out of contract, perhaps in the next year or so… you\\'re looking at your Liam Williams\\' of the world; Justin Tipuric for example - we need to keep these guys in Wales.\\n\"We actually want them there. They are the ones who are going to impress the young kids, for example.\\n\"They are the sort of heroes that our young kids want to emulate.\\n\"So I would start off [by saying] with the limited pot of money, we have to retain players in Wales.\\n\"Now, if that can be done and there\\'s some spare monies available at the end, yes, let\\'s look to bring players back.\\n\"But it\\'s a cruel world, isn\\'t it?\\n\"It\\'s fine to take the buck and go, but great if you can get them back as well, provided there\\'s enough money.\"\\nBritish and Irish Lions centre Roberts has insisted he will see out his Racing Metro contract.\\nHe and Phillips also earlier dismissed the idea of leaving Paris.\\nRoberts also admitted being hurt by comments in French Newspaper L\\'Equipe attributed to Racing Coach Laurent Labit questioning their effectiveness.\\nCentre Roberts and flanker Lydiate joined Racing ahead of the 2013-14 season while scrum-half Phillips moved there in December 2013 after being dismissed for disciplinary reasons by former club Bayonne.',\n",
              " 'id': '29750031',\n",
              " 'summary': 'New Welsh Rugby Union chairman Gareth Davies believes a joint £3.3m WRU-regions fund should be used to retain home-based talent such as Liam Williams, not bring back exiled stars.'}"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 5
        }
      ],
      "metadata": {
        "id": "X6HrpprwIrIz",
        "outputId": "d7670bc0-42e4-4c09-8a6a-5c018ded7d95"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "To get a better sense of what the data looks like, the following function displays a few examples picked at random from the dataset."
      ],
      "metadata": {
        "id": "WHUmphG3IrI3"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "source": [
        "import datasets\n",
        "import random\n",
        "import pandas as pd\n",
        "from IPython.display import display, HTML\n",
        "\n",
        "def show_random_elements(dataset, num_examples=2):\n",
        "    assert num_examples <= len(dataset), \"Can't pick more elements than there are in the dataset.\"\n",
        "    picks = []\n",
        "    for _ in range(num_examples):\n",
        "        pick = random.randint(0, len(dataset)-1)\n",
        "        while pick in picks:\n",
        "            pick = random.randint(0, len(dataset)-1)\n",
        "        picks.append(pick)\n",
        "    \n",
        "    df = pd.DataFrame(dataset[picks])\n",
        "    for column, typ in dataset.features.items():\n",
        "        if isinstance(typ, datasets.ClassLabel):\n",
        "            df[column] = df[column].transform(lambda i: typ.names[i])\n",
        "    display(HTML(df.to_html()))"
      ],
      "outputs": [],
      "metadata": {
        "id": "i3j8APAoIrI3"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "source": [
        "show_random_elements(raw_datasets[\"train\"])"
      ],
      "outputs": [
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ],
            "text/html": [
              "<table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              "    <tr style=\"text-align: right;\">\n",
              "      <th></th>\n",
              "      <th>document</th>\n",
              "      <th>id</th>\n",
              "      <th>summary</th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "    <tr>\n",
              "      <th>0</th>\n",
              "      <td>Media playback is not supported on this device\\nThe 31-year-old Portugal captain has the chance to win the first international title of his glittering career when his side take on the tournament hosts at Stade de France.\\n\"Everyone in the country will be against him, but he will thrive off that hostility, and off their fear,\" Ferdinand told BBC Sport.\\n\"The French know Cristiano is the player capable of destroying their dream because he has produced magic moments in huge matches right through his career.\\n\"Part of the reason he is a superstar is because he is not fazed by the big occasion - quite the opposite in fact.\\n\"Superstars like him relish these situations - the pressure that goes with it brings the best out of him, when other players falter.\"\\nRonaldo helped Portugal reach the final when they hosted Euro 2004, but was left crying on the pitch after defeat by Greece, and they have not reached this stage of a major tournament since.\\n\"He was only 19 then, so just a kid,\" added Ferdinand, who played alongside Ronaldo at Old Trafford between 2003 and 2009 and is working in France as a BBC pundit.\\n\"I remember his reaction but I think he was a bit too young to take it all in. At that age, you expect you will get back to another final soon to rectify what happened but he has had to wait 12 years for his chance.\\n\"I think he is very aware this is his last opportunity to win something with his country and, knowing him like I do, that makes him even more dangerous. 
He will be so desperate not to miss out again.\\n\"Cristiano has produced great performances for Portugal when it matters before, for example his hat-trick against Sweden in the play-off for the 2014 World Cup.\\n\"So France will know that it is not just in a Real Madrid or Manchester United shirt that he is capable of great things, especially because it was his moment of brilliance that helped decide Portugal's semi-final against Wales.\"\\nRonaldo is now the joint highest scorer in European Championship history with nine goals, and holds the record for most headed goals with five.\\nTwo of them have come in France, most notably when he soared to nod Portugal in front against Wales, but Ferdinand says that part of Ronaldo's game is nothing new.\\n\"He has always been amazing on the ball but even when he first joined United in 2003 he was great in the air too,\" he said.\\n\"Early in his career it was a part of his game that was quite undervalued but he always scored a lot of headers and, the way he does it, he is the closest thing in football to basketball legend Michael Jordan.\\n\"The way he jumps and hangs in the air is the same as Jordan and he has got the ability to stay up there, assess the situation and then put the ball where he wants to, with power.\\n\"I will always remember the header he scored for United away at Roma in the Champions League quarter-finals in April 2008.\\n\"He more or less jumped on the edge of the box to meet a cross that Paul Scholes put over but he met the ball a good way inside the area. If you watch TV footage of that game, he just appears from nowhere and smashes it into the bottom corner.\\n\"Just like his goal against Wales, it was an unbelievable jump and he generated incredible power. 
I was on the pitch that night, and it was amazing to see.\\n\"Cristiano's heading ability will be a huge threat in the final too, along with Nani - another former United team-mate of mine.\\n\"France have struggled to defend crosses for most of the tournament and, although they were better at it in the semi-final, Germany did not have anyone to aim at in the box.\"\\nEighty-six players at Euro 2016 have completed more dribbles than the three Ronaldo has managed in six matches.\\n\"He was always well known for his brilliant runs forward but his game is not about that any more,\" said Ferdinand, who was in the same United team as Ronaldo when he scored 42 goals in the 2007-08 season.\\n\"Before, he used to exert a lot of energy trying to take people on from deep areas, running at goal from 30 or 40 yards, or even further out.\\n\"Now, he is very clever in where he tries to receive the ball. It has to be in good positions and, when he gets it, he finds a yard of space and hits it - either a shot or a cross.\\n\"Part of the reason he has been able to reinvent himself is because of how hard he works - right from the start of his career, when we were together at Old Trafford, he was totally committed to improving every part of his game.\\n\"But to be able to re-evaluate his game and change it is also down to his football intelligence.\\n\"Clearly he is clever - you do not score 50 goals a season, six seasons running, for Real Madrid if you are not.\\n\"But his extra intelligence has allowed him to evolve as a player, understand his body, where it can take him and how often.\\n\"He has become a much more efficient player, but is still an extremely effective one.\"\\nSunday offers Ronaldo the opportunity for personal glory too, with the chance to get one over Lionel Messi in the battle to be viewed as the best player in the world.\\nMessi has never won a major tournament for Argentina and announced his international retirement last month after they lost to Chile in the 
final of the Copa America.\\n\"If Ronaldo wins the European Championship, it will be massive for him,\" said Ferdinand, 37.\\n\"I don't think it will give him the edge over Messi in terms of who deserves the accolade of the best in the world, but it is a huge achievement.\\n\"And it will matter to both of them, because there is a definite battle between them in their own minds about who has done what for club and country.\\n\"It is far from a given that Ronaldo will manage it, of course. France are looking very good and they have a game-changer in Antoine Griezmann.\\n\"Even if Ronaldo is at his best, it is a difficult ask for them and I think they are going to have to play ugly, like they have done all the way through the tournament, to win.\\n\"I have known him a long time and I would love to see him do it, but it is awkward for me because I have friends in both camps - Nani and Cristiano for Portugal, and Paul Pogba and Patrice Evra for France.\\n\"So I am not really bothered about the result, I just want to see a good game. I would love to see Ronaldo and Griezmann perform to their potential and finish off this tournament on a high note.\"</td>\n",
              "      <td>36749142</td>\n",
              "      <td>Cristiano Ronaldo will relish taking on the whole of France in Sunday's Euro 2016 final, according to his former Manchester United team-mate Rio Ferdinand.</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>1</th>\n",
              "      <td>Media playback is unsupported on your device\\n18 December 2014 Last updated at 10:28 GMT\\nMalaysia has successfully tackled poverty over the last four decades by drawing on its rich natural resources.\\nAccording to the World Bank, some 49% of Malaysians in 1970 were extremely poor, and that figure has been reduced to 1% today. However, the government's next challenge is to help the lower income group to move up to the middle class, the bank says.\\nUlrich Zahau, the World Bank's Southeast Asia director, spoke to the BBC's Jennifer Pak.</td>\n",
              "      <td>30530533</td>\n",
              "      <td>In Malaysia the \"aspirational\" low-income part of the population is helping to drive economic growth through consumption, according to the World Bank.</td>\n",
              "    </tr>\n",
              "  </tbody>\n",
              "</table>"
            ]
          },
          "metadata": {
            "tags": []
          }
        }
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 791
        },
        "id": "SZy5tRB_IrI7",
        "outputId": "17194536-108f-4dc2-f5c2-dccf93c0164e"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "`metric` is an instance of [`datasets.Metric`](https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Metric). Displaying it shows its description and a usage example:"
      ],
      "metadata": {
        "id": "lnjDIuQ3IrI-"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "source": [
        "metric"
      ],
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "Metric(name: \"rouge\", features: {'predictions': Value(dtype='string', id='sequence'), 'references': Value(dtype='string', id='sequence')}, usage: \"\"\"\n",
              "Calculates average rouge scores for a list of hypotheses and references\n",
              "Args:\n",
              "    predictions: list of predictions to score. Each predictions\n",
              "        should be a string with tokens separated by spaces.\n",
              "    references: list of reference for each prediction. Each\n",
              "        reference should be a string with tokens separated by spaces.\n",
              "    rouge_types: A list of rouge types to calculate.\n",
              "        Valid names:\n",
              "        `\"rouge{n}\"` (e.g. `\"rouge1\"`, `\"rouge2\"`) where: {n} is the n-gram based scoring,\n",
              "        `\"rougeL\"`: Longest common subsequence based scoring.\n",
              "        `\"rougeLSum\"`: rougeLsum splits text using `\"\n",
              "\"`.\n",
              "        See details in https://github.com/huggingface/datasets/issues/617\n",
              "    use_stemmer: Bool indicating whether Porter stemmer should be used to strip word suffixes.\n",
              "    use_agregator: Return aggregates if this is set to True\n",
              "Returns:\n",
              "    rouge1: rouge_1 (precision, recall, f1),\n",
              "    rouge2: rouge_2 (precision, recall, f1),\n",
              "    rougeL: rouge_l (precision, recall, f1),\n",
              "    rougeLsum: rouge_lsum (precision, recall, f1)\n",
              "Examples:\n",
              "\n",
              "    >>> rouge = datasets.load_metric('rouge')\n",
              "    >>> predictions = [\"hello there\", \"general kenobi\"]\n",
              "    >>> references = [\"hello there\", \"general kenobi\"]\n",
              "    >>> results = rouge.compute(predictions=predictions, references=references)\n",
              "    >>> print(list(results.keys()))\n",
              "    ['rouge1', 'rouge2', 'rougeL', 'rougeLsum']\n",
              "    >>> print(results[\"rouge1\"])\n",
              "    AggregateScore(low=Score(precision=1.0, recall=1.0, fmeasure=1.0), mid=Score(precision=1.0, recall=1.0, fmeasure=1.0), high=Score(precision=1.0, recall=1.0, fmeasure=1.0))\n",
              "    >>> print(results[\"rouge1\"].mid.fmeasure)\n",
              "    1.0\n",
              "\"\"\", stored examples: 0)"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 8
        }
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "5o4rUteaIrI_",
        "outputId": "63ba0150-f748-4cc0-fda1-dada85bcdf79"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "We call the `compute` method to score predictions against labels. Both must be lists of strings, in the format shown in the example below:"
      ],
      "metadata": {
        "id": "jAWdqcUBIrJC"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "source": [
        "fake_preds = [\"hello there\", \"general kenobi\"]\n",
        "fake_labels = [\"hello there\", \"general kenobi\"]\n",
        "metric.compute(predictions=fake_preds, references=fake_labels)"
      ],
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "{'rouge1': AggregateScore(low=Score(precision=1.0, recall=1.0, fmeasure=1.0), mid=Score(precision=1.0, recall=1.0, fmeasure=1.0), high=Score(precision=1.0, recall=1.0, fmeasure=1.0)),\n",
              " 'rouge2': AggregateScore(low=Score(precision=1.0, recall=1.0, fmeasure=1.0), mid=Score(precision=1.0, recall=1.0, fmeasure=1.0), high=Score(precision=1.0, recall=1.0, fmeasure=1.0)),\n",
              " 'rougeL': AggregateScore(low=Score(precision=1.0, recall=1.0, fmeasure=1.0), mid=Score(precision=1.0, recall=1.0, fmeasure=1.0), high=Score(precision=1.0, recall=1.0, fmeasure=1.0)),\n",
              " 'rougeLsum': AggregateScore(low=Score(precision=1.0, recall=1.0, fmeasure=1.0), mid=Score(precision=1.0, recall=1.0, fmeasure=1.0), high=Score(precision=1.0, recall=1.0, fmeasure=1.0))}"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 9
        }
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "6XN1Rq0aIrJC",
        "outputId": "56fc8845-3d36-4490-b281-54096bf82ad4"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Preprocessing the data"
      ],
      "metadata": {
        "id": "n9qywopnIrJH"
      }
    },
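    {
      "cell_type": "markdown",
      "source": [
        "Before feeding the data to the model, we need to preprocess it with a tokenizer. The tokenizer first splits the input into tokens, then maps each token to its ID in the pretrained vocabulary, and finally packs the result into the input format the model expects.\n",
        "\n",
        "We instantiate the tokenizer with `AutoTokenizer.from_pretrained`, which ensures that:\n",
        "\n",
        "- we get the tokenizer that corresponds to the pretrained model we chose;\n",
        "- the vocabulary used when that checkpoint was pretrained is downloaded along with it.\n",
        "\n",
        "\n",
        "The downloaded vocabulary is cached, so it is not downloaded again the next time the cell is run."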
      ],
      "metadata": {
        "id": "YVx71GdAIrJH"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "source": [
        "from transformers import AutoTokenizer\n",
        "    \n",
        "tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)"
      ],
      "outputs": [],
      "metadata": {
        "id": "eXNLu_-nIrJI"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "By default, the call above will use one of the fast tokenizers (backed by Rust) from the 🤗 Tokenizers library."
      ],
      "metadata": {
        "id": "Vl6IidfdIrJK"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "The tokenizer can be called on a single text or on a pair of texts, and its output matches the input format the pretrained model expects."
      ],
      "metadata": {
        "id": "rowT4iCLIrJK"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "source": [
        "tokenizer(\"Hello, this one sentence!\")"
      ],
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "{'input_ids': [8774, 6, 48, 80, 7142, 55, 1], 'attention_mask': [1, 1, 1, 1, 1, 1, 1]}"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 11
        }
      ],
      "metadata": {
        "id": "a5hBlsrHIrJL",
        "outputId": "acdaa98a-a8cd-4a20-89b8-cc26437bbe90"
      }
    },
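    {
      "cell_type": "markdown",
      "source": [
        "The token IDs above (`input_ids`) generally differ from one pretrained model to another, because each model was pretrained with its own tokenization rules. As long as the tokenizer and the model are loaded from the same checkpoint name, the tokenizer's output matches what the model expects. For more on preprocessing, see [this tutorial](https://huggingface.co/transformers/preprocessing.html).\n",
        "\n",
        "Besides a single sentence, we can also tokenize a list of sentences."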
      ],
      "metadata": {
        "id": "qo_0B1M2IrJM"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "source": [
        "tokenizer([\"Hello, this one sentence!\", \"This is another sentence.\"])"
      ],
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "{'input_ids': [[8774, 6, 48, 80, 7142, 55, 1], [100, 19, 430, 7142, 5, 1]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1]]}"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 11
        }
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "0YFUpXROdJT6",
        "outputId": "3659a781-f0d5-48db-9331-34df0452a5a8"
      }
    },
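    {
      "cell_type": "markdown",
      "source": [
        "Note: to prepare the targets for the model (here, the summaries), we tokenize them inside the `as_target_tokenizer` context manager, so that the special tokens appropriate for targets are used:"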
      ],
      "metadata": {
        "id": "skFBG0iBdJT7"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "source": [
        "with tokenizer.as_target_tokenizer():\n",
        "    print(tokenizer([\"Hello, this one sentence!\", \"This is another sentence.\"]))"
      ],
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "{'input_ids': [[8774, 6, 48, 80, 7142, 55, 1], [100, 19, 430, 7142, 5, 1]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1]]}\n"
          ]
        }
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "E2VZkR_XdJT7",
        "outputId": "12732226-b243-400e-fdfc-7b137c884438"
      }
    },
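    {
      "cell_type": "markdown",
      "source": [
        "If you are using one of the T5 checkpoints, the inputs need a task prefix. T5 uses such a prefix to tell the model which task to perform; for summarization it is set as follows:\n"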
      ],
      "metadata": {
        "id": "2C0hcmp9IrJQ"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "source": [
        "if model_checkpoint in [\"t5-small\", \"t5-base\", \"t5-large\", \"t5-3b\", \"t5-11b\"]:\n",
        "    prefix = \"summarize: \"\n",
        "else:\n",
        "    prefix = \"\""
      ],
      "outputs": [],
      "metadata": {
        "id": "LFntjGh8dJT7"
      }
    },
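    {
      "cell_type": "markdown",
      "source": [
        "We can now put everything together in a preprocessing function. We pass `truncation=True` so that inputs longer than the maximum length are truncated. The tokenizer call does not pad shorter sequences here; padding is typically left to a later step, when examples are assembled into batches."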
      ],
      "metadata": {
        "id": "rZMGxnr5dJT8"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "source": [
        "max_input_length = 1024\n",
        "max_target_length = 128\n",
        "\n",
        "def preprocess_function(examples):\n",
        "    inputs = [prefix + doc for doc in examples[\"document\"]]\n",
        "    model_inputs = tokenizer(inputs, max_length=max_input_length, truncation=True)\n",
        "\n",
        "    # Setup the tokenizer for targets\n",
        "    with tokenizer.as_target_tokenizer():\n",
        "        labels = tokenizer(examples[\"summary\"], max_length=max_target_length, truncation=True)\n",
        "\n",
        "    model_inputs[\"labels\"] = labels[\"input_ids\"]\n",
        "    return model_inputs"
      ],
      "outputs": [],
      "metadata": {
        "id": "vc0BSBLIIrJQ"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "The function above expects a batch of examples, such as the slice `raw_datasets['train'][:2]` used below. For each key, it returns a list with one entry per example."
      ],
      "metadata": {
        "id": "0lm8ozrJIrJR"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "source": [
        "preprocess_function(raw_datasets['train'][:2])"
      ],
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "{'input_ids': [[21603, 10, 17716, 2279, 43, 5229, 128, 1410, 18, 390, 1508, 28, 5146, 12, 10256, 5, 96, 196, 31, 162, 373, 1800, 3, 18, 11, 48, 19, 28, 82, 22209, 3, 547, 30, 230, 117, 48, 19, 59, 1719, 42, 549, 8503, 3, 18, 27, 31, 26, 1066, 1492, 24, 540, 30, 2627, 1508, 16, 10256, 976, 243, 28571, 5, 37, 549, 8503, 795, 17586, 51, 12, 8, 3069, 11, 3996, 13606, 51, 639, 45, 8, 6266, 5, 18263, 10256, 11, 2390, 11, 7262, 10371, 7, 3971, 18, 17114, 28571, 1632, 549, 8503, 13404, 30, 2818, 1401, 1797, 6, 7229, 53, 20, 12151, 1955, 8356, 49, 53, 826, 3, 19585, 643, 9768, 5, 216, 19, 230, 3122, 3, 9, 2103, 1059, 12, 1175, 112, 1075, 38, 24260, 350, 16103, 10282, 7, 5752, 4297, 227, 271, 3, 11060, 30, 12, 8, 549, 8503, 1476, 16, 1600, 5, 28571, 47, 859, 8, 1374, 5638, 859, 10282, 7, 6, 411, 7, 2026, 63, 7, 6, 14586, 7677, 11, 26911, 2419, 7, 4298, 113, 130, 10960, 52, 26786, 16, 3, 9, 813, 11674, 11044, 28, 8, 549, 8503, 24, 3492, 16, 3, 9, 3996, 3328, 51, 1154, 16, 1660, 48, 215, 5, 86, 8, 7178, 13, 24, 1154, 271, 612, 6, 28571, 243, 8, 3996, 19660, 51, 225, 36, 1869, 30, 3, 5833, 750, 10256, 18, 390, 4811, 2367, 132, 5, 86, 1100, 1274, 6, 16046, 10730, 24397, 49, 2744, 31914, 17, 15, 47, 5229, 28, 7646, 12, 10256, 5, 3, 21322, 8, 1919, 1886, 31, 7, 14667, 440, 18, 17114, 4794, 16202, 7, 11, 2050, 17845, 2715, 7, 130, 92, 2633, 26, 21, 487, 5146, 5, 10256, 3763, 16700, 2776, 17, 40, 232, 65, 243, 10, 96, 1326, 43, 29, 31, 17, 16, 7, 2880, 920, 574, 28, 8, 1508, 5, 96, 11836, 62, 33, 2718, 24, 80, 42, 192, 13, 135, 33, 9805, 12, 1205, 12, 10256, 14159, 1066, 145, 865, 535, 14734, 12, 4712, 2781, 584, 30, 9938, 5061, 10256, 6, 28571, 3, 60, 18, 155, 15, 4094, 112, 3, 8389, 6, 2145, 2627, 1508, 224, 38, 14586, 7677, 423, 18, 1549, 1414, 265, 6060, 11, 411, 7, 2026, 63, 7, 24397, 49, 12446, 2262, 3791, 447, 16, 10256, 225, 240, 20799, 1433, 5, 96, 196, 17, 31, 7, 6865, 3, 9, 1643, 866, 13, 540, 784, 28843, 4275, 37, 7021, 33, 12932, 15436, 13, 24, 1696, 11, 
8, 6266, 33, 3, 3131, 3996, 13606, 51, 16, 5, 96, 5231, 34, 31, 7, 3, 9, 792, 815, 13, 131, 147, 23395, 51, 11, 3, 99, 25, 320, 44, 8, 10549, 13, 21331, 24, 8, 233, 3413, 233, 43, 118, 3, 22765, 12, 281, 10055, 21, 784, 355, 908, 1516, 6201, 13, 540, 5, 96, 5231, 3, 99, 62, 130, 12, 830, 8, 1508, 223, 6, 62, 31, 26, 1077, 129, 874, 42, 1296, 1508, 5, 96, 7175, 27, 31, 162, 373, 1800, 3, 18, 11, 48, 19, 28, 82, 22209, 3, 547, 30, 230, 117, 48, 19, 59, 1719, 42, 549, 8503, 3, 18, 27, 31, 26, 1066, 1492, 24, 540, 30, 2627, 1508, 16, 10256, 5, 96, 7238, 33, 1508, 1107, 91, 13, 1696, 6, 2361, 16, 8, 416, 215, 42, 78, 233, 25, 31, 60, 479, 44, 39, 1414, 265, 6060, 31, 13, 8, 296, 117, 12446, 2262, 3791, 447, 21, 677, 3, 18, 62, 174, 12, 453, 175, 3413, 16, 10256, 5, 96, 1326, 700, 241, 135, 132, 5, 328, 33, 8, 2102, 113, 33, 352, 12, 18514, 8, 1021, 1082, 6, 21, 677, 5, 96, 10273, 33, 8, 1843, 13, 17736, 24, 69, 1021, 1082, 241, 12, 29953, 5, 96, 5231, 27, 133, 456, 326, 784, 969, 2145, 908, 28, 8, 1643, 815, 13, 540, 6, 62, 43, 12, 7365, 1508, 16, 10256, 5, 96, 17527, 6, 3, 99, 24, 54, 36, 612, 11, 132, 31, 7, 128, 8179, 3, 26413, 347, 44, 8, 414, 6, 4273, 6, 752, 31, 7, 320, 12, 830, 1508, 223, 5, 96, 11836, 34, 31, 7, 3, 9, 23958, 296, 6, 19, 29, 31, 17, 34, 58, 96, 196, 17, 31, 7, 1399, 12, 240, 8, 3, 13863, 11, 281, 6, 68, 248, 3, 99, 25, 54, 129, 135, 223, 38, 168, 6, 937, 132, 31, 7, 631, 540, 535, 2390, 11, 7262, 10371, 7, 2050, 2715, 7, 65, 16, 15777, 3, 88, 56, 217, 91, 112, 16046, 10730, 1696, 5, 216, 11, 16202, 7, 92, 2283, 19664, 8, 800, 13, 3140, 1919, 5, 2715, 7, 92, 10246, 271, 4781, 57, 2622, 16, 2379, 29494, 301, 31, 427, 23067, 15, 3, 20923, 12, 16046, 9493, 9906, 17, 325, 2360, 822, 53, 70, 9570, 5, 2969, 2715, 7, 11, 24397, 49, 31914, 17, 15, 3311, 16046, 2177, 13, 8, 2038, 11590, 774, 298, 14667, 440, 18, 17114, 16202, 7, 2301, 132, 16, 1882, 2038, 227, 271, 19664, 21, 3, 15471, 2081, 57, 1798, 1886, 2474, 5993, 5, 1], [21603, 10, 6788, 19758, 7, 
2273, 130, 718, 91, 12, 1154, 28, 3, 9, 6220, 2642, 44, 8, 6036, 30, 8, 368, 3540, 986, 7, 2409, 30, 1701, 706, 5, 2409, 7, 130, 16645, 326, 11, 2117, 12355, 1054, 38, 3, 9, 6478, 16813, 47, 4006, 91, 5, 37, 12787, 6, 261, 57, 1932, 27874, 5220, 31195, 3230, 6, 43, 118, 7774, 3, 9, 381, 13, 648, 5, 1377, 1310, 6, 17602, 6417, 6032, 130, 4006, 91, 30, 8, 6036, 30, 12096, 8348, 16, 1186, 11, 932, 5, 37, 6032, 1553, 826, 3, 9, 27874, 896, 2063, 2902, 16, 1882, 1673, 18395, 53, 8, 7070, 13, 8, 7021, 5692, 44, 8, 896, 2501, 5, 1193, 1778, 29, 53, 8, 1251, 3534, 9, 226, 6, 11529, 283, 4569, 4409, 5225, 8692, 243, 10, 96, 196, 17, 19, 3, 9, 2261, 5415, 21, 8, 415, 616, 6, 34, 4110, 2261, 17879, 6, 34, 10762, 151, 31, 7, 1342, 44, 1020, 6, 34, 54, 1709, 3583, 364, 7232, 8, 616, 5, 96, 19494, 62, 174, 151, 28, 251, 12, 698, 24, 28, 8, 2095, 16, 455, 21, 135, 12, 103, 70, 613, 11, 830, 175, 151, 12, 4831, 535, 1]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], 'labels': 
[[368, 22982, 24895, 3545, 13404, 3121, 15, 189, 28571, 7228, 3, 9, 4494, 3996, 19660, 51, 549, 8503, 18, 18145, 7, 3069, 225, 36, 261, 12, 7365, 234, 18, 390, 3683, 224, 38, 1414, 265, 6060, 6, 59, 830, 223, 1215, 699, 26, 4811, 5, 1], [71, 21641, 2642, 646, 1067, 46, 11529, 3450, 828, 16, 5727, 27874, 65, 118, 10126, 3, 9, 3534, 9, 226, 5, 1]]}"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 15
        }
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "-b70jh26IrJS",
        "outputId": "f5a50c2b-0106-41bb-ffb2-3a7c445308f6"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "Next, we preprocess every example in `raw_datasets`. We do this with the `map` method, which applies the preprocessing function `preprocess_function` to all examples."
      ],
      "metadata": {
        "id": "zS-6iXTkIrJT"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "source": [
        "tokenized_datasets = raw_datasets.map(preprocess_function, batched=True)"
      ],
      "outputs": [],
      "metadata": {
        "id": "DDtsaJeVIrJT"
      }
    },
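    {
      "cell_type": "markdown",
      "source": [
        "Even better, the 🤗 Datasets library caches the results automatically, so they are not recomputed the next time you run this step. The library checks whether the function passed to `map` has changed: if it has not, the cached data is reused; if it has, the data is reprocessed. Be careful, though: if something changes that the library does not detect, stale cached results may be returned. In that case, pass `load_from_cache_file=False` to `map` to bypass the cache. Also, the `batched=True` argument used above passes the examples to the preprocessing function in batches. This takes full advantage of the tokenizer, because a fast tokenizer uses multiple threads to process the texts in a batch in parallel."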
      ],
      "metadata": {
        "id": "voWiw8C7IrJV"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Fine-tuning the model"
      ],
      "metadata": {
        "id": "545PP3o8IrJV"
      }
    },
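    {
      "cell_type": "markdown",
      "source": [
        "Now that the data is ready, we need to download and load the pretrained model and then fine-tune it. Since this is a sequence-to-sequence (seq2seq) task, we need a model class designed for it, so we use `AutoModelForSeq2SeqLM`. As with the tokenizer, the `from_pretrained` method downloads and loads the model and caches it, so the model is not downloaded again on later runs."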
      ],
      "metadata": {
        "id": "FBiW8UpKIrJW"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "source": [
        "from transformers import AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, Seq2SeqTrainer\n",
        "\n",
        "model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)"
      ],
      "outputs": [],
      "metadata": {
        "id": "TlqNaB8jIrJW"
      }
    },
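    {
      "cell_type": "markdown",
      "source": [
        "Because we are fine-tuning on a seq2seq task and the checkpoint we load is itself a pretrained seq2seq model, loading it does not produce a warning about discarded or mismatched weights (for example, a warning that the pretraining head was thrown away and a task-specific head was randomly initialized)."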
      ],
      "metadata": {
        "id": "CczA5lJlIrJX"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "To instantiate a `Seq2SeqTrainer`, we need three more things. The most important is the training configuration, [`Seq2SeqTrainingArguments`](https://huggingface.co/transformers/main_classes/trainer.html#transformers.Seq2SeqTrainingArguments), which holds all the attributes that define the training process."
      ],
      "metadata": {
        "id": "_N8urzhyIrJY"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "source": [
        "batch_size = 16\n",
        "args = Seq2SeqTrainingArguments(\n",
        "    \"test-summarization\",\n",
        "    evaluation_strategy=\"epoch\",\n",
        "    learning_rate=2e-5,\n",
        "    per_device_train_batch_size=batch_size,\n",
        "    per_device_eval_batch_size=batch_size,\n",
        "    weight_decay=0.01,\n",
        "    save_total_limit=3,\n",
        "    num_train_epochs=1,\n",
        "    predict_with_generate=True,\n",
        "    fp16=True,\n",
        ")"
      ],
      "outputs": [],
      "metadata": {
        "id": "Bliy8zgjIrJY"
      }
    },
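    {
      "cell_type": "markdown",
      "source": [
        "The `evaluation_strategy=\"epoch\"` argument tells the training code to run evaluation on the validation set once at the end of every epoch.\n",
        "\n",
        "`batch_size` is defined at the top of the cell above.\n",
        "\n",
        "Because our dataset is fairly large and `Seq2SeqTrainer` saves checkpoints regularly, we set `save_total_limit=3` so that at most three checkpoints are kept on disk.\n",
        "\n",
        "Finally, we need a data collator, which batches our preprocessed examples and feeds them to the model."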
      ],
      "metadata": {
        "id": "km3pGVdTIrJc"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "source": [
        "data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)"
      ],
      "outputs": [],
      "metadata": {
        "id": "osCjNg41dJT_"
      }
    },
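    {
      "cell_type": "markdown",
      "source": [
        "The last thing to define before creating the `Seq2SeqTrainer` is how to compute the evaluation metrics. We use the `metric` object for this, and we post-process the model predictions before passing them to it:"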
      ],
      "metadata": {
        "id": "7sZOdRlRIrJd"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "source": [
        "import nltk\n",
        "import numpy as np\n",
        "\n",
        "def compute_metrics(eval_pred):\n",
        "    predictions, labels = eval_pred\n",
        "    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)\n",
        "    # Replace -100 in the labels as we can't decode them.\n",
        "    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)\n",
        "    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)\n",
        "    \n",
        "    # Rouge expects a newline after each sentence\n",
        "    decoded_preds = [\"\\n\".join(nltk.sent_tokenize(pred.strip())) for pred in decoded_preds]\n",
        "    decoded_labels = [\"\\n\".join(nltk.sent_tokenize(label.strip())) for label in decoded_labels]\n",
        "    \n",
        "    result = metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)\n",
        "    # Extract a few results\n",
        "    result = {key: value.mid.fmeasure * 100 for key, value in result.items()}\n",
        "    \n",
        "    # Add mean generated length\n",
        "    prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in predictions]\n",
        "    result[\"gen_len\"] = np.mean(prediction_lens)\n",
        "    \n",
        "    return {k: round(v, 4) for k, v in result.items()}"
      ],
      "outputs": [],
      "metadata": {
        "id": "UmvbnJ9JIrJd"
      }
    },
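    {
      "cell_type": "markdown",
      "source": [
        "Finally, we pass the model, the training arguments, the datasets, the data collator, the tokenizer, and the metric function to `Seq2SeqTrainer`:"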
      ],
      "metadata": {
        "id": "rXuFTAzDIrJe"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "source": [
        "trainer = Seq2SeqTrainer(\n",
        "    model,\n",
        "    args,\n",
        "    train_dataset=tokenized_datasets[\"train\"],\n",
        "    eval_dataset=tokenized_datasets[\"validation\"],\n",
        "    data_collator=data_collator,\n",
        "    tokenizer=tokenizer,\n",
        "    compute_metrics=compute_metrics\n",
        ")"
      ],
      "outputs": [],
      "metadata": {
        "id": "imY1oC3SIrJf"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "Call the `train` method to fine-tune the model."
      ],
      "metadata": {
        "id": "CdzABDVcIrJg"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "source": [
        "trainer.train()"
      ],
      "outputs": [
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ],
            "text/html": [
              "\n",
              "    <div>\n",
              "        <style>\n",
              "            /* Turns off some styling */\n",
              "            progress {\n",
              "                /* gets rid of default border in Firefox and Opera. */\n",
              "                border: none;\n",
              "                /* Needs to be in here for Safari polyfill so background images work as expected. */\n",
              "                background-size: auto;\n",
              "            }\n",
              "        </style>\n",
              "      \n",
              "      <progress value='12753' max='12753' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
              "      [12753/12753 1:21:48, Epoch 1/1]\n",
              "    </div>\n",
              "    <table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              "    <tr style=\"text-align: left;\">\n",
              "      <th>Epoch</th>\n",
              "      <th>Training Loss</th>\n",
              "      <th>Validation Loss</th>\n",
              "      <th>Rouge1</th>\n",
              "      <th>Rouge2</th>\n",
              "      <th>Rougel</th>\n",
              "      <th>Rougelsum</th>\n",
              "      <th>Gen Len</th>\n",
              "      <th>Runtime</th>\n",
              "      <th>Samples Per Second</th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "    <tr>\n",
              "      <td>1</td>\n",
              "      <td>2.721100</td>\n",
              "      <td>2.479327</td>\n",
              "      <td>28.300900</td>\n",
              "      <td>7.721100</td>\n",
              "      <td>22.243000</td>\n",
              "      <td>22.249600</td>\n",
              "      <td>18.822500</td>\n",
              "      <td>326.333800</td>\n",
              "      <td>34.725000</td>\n",
              "    </tr>\n",
              "  </tbody>\n",
              "</table><p>"
            ]
          },
          "metadata": {
            "tags": []
          }
        },
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "TrainOutput(global_step=12753, training_loss=2.7692033505520146, metrics={'train_runtime': 4909.3835, 'train_samples_per_second': 2.598, 'total_flos': 7.774481450954342e+16, 'epoch': 1.0, 'init_mem_cpu_alloc_delta': 335248, 'init_mem_gpu_alloc_delta': 242026496, 'init_mem_cpu_peaked_delta': 18306, 'init_mem_gpu_peaked_delta': 0, 'train_mem_cpu_alloc_delta': 2637782, 'train_mem_gpu_alloc_delta': 728138240, 'train_mem_cpu_peaked_delta': 138226182, 'train_mem_gpu_peaked_delta': 14677017088})"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 23
        }
      ],
      "metadata": {
        "id": "uNx5pyRlIrJh",
        "outputId": "077e661e-d36c-469b-89b8-7ff7f73541ec",
        "scrolled": false
      }
    },
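    {
      "cell_type": "markdown",
      "source": [
        "Finally, don't forget to [upload your model](https://huggingface.co/transformers/model_sharing.html) to the [🤗 Model Hub](https://huggingface.co/models). You can then load your model directly by its name, just as we did with the checkpoint at the beginning of this notebook.\n"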
      ],
      "metadata": {
        "id": "ha0-YirPdJUB"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "source": [],
      "outputs": [],
      "metadata": {
        "id": "NC7aCpkBdJUB"
      }
    }
  ],
  "metadata": {
    "colab": {
      "name": "4.7-生成任务-摘要生成",
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.7.9"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 2
}